{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "pf = pd.read_csv('data/FEdata.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "pf = pf.drop(['Over18','StandardHours'],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Attrition</th>\n",
       "      <th>BusinessTravel</th>\n",
       "      <th>Department</th>\n",
       "      <th>DistanceFromHome</th>\n",
       "      <th>Education</th>\n",
       "      <th>EducationField</th>\n",
       "      <th>EmployeeNumber</th>\n",
       "      <th>EnvironmentSatisfaction</th>\n",
       "      <th>Gender</th>\n",
       "      <th>...</th>\n",
       "      <th>PerformanceRating</th>\n",
       "      <th>RelationshipSatisfaction</th>\n",
       "      <th>StockOptionLevel</th>\n",
       "      <th>TotalWorkingYears</th>\n",
       "      <th>TrainingTimesLastYear</th>\n",
       "      <th>WorkLifeBalance</th>\n",
       "      <th>YearsAtCompany</th>\n",
       "      <th>YearsInCurrentRole</th>\n",
       "      <th>YearsSinceLastPromotion</th>\n",
       "      <th>YearsWithCurrManager</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>37</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>77</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>7</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>54</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1245</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>33</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>34</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>147</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>9</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>39</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1026</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>21</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>21</td>\n",
       "      <td>6</td>\n",
       "      <td>11</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1111</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 29 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Age  Attrition  BusinessTravel  Department  DistanceFromHome  Education  \\\n",
       "0   37          0               2           1                 1          4   \n",
       "1   54          0               1           1                 1          4   \n",
       "2   34          1               1           1                 7          3   \n",
       "3   39          0               2           1                 1          1   \n",
       "4   28          1               1           1                 1          3   \n",
       "\n",
       "   EducationField  EmployeeNumber  EnvironmentSatisfaction  Gender  \\\n",
       "0               1              77                        1       1   \n",
       "1               1            1245                        4       0   \n",
       "2               1             147                        1       1   \n",
       "3               1            1026                        4       0   \n",
       "4               3            1111                        1       1   \n",
       "\n",
       "           ...           PerformanceRating  RelationshipSatisfaction  \\\n",
       "0          ...                           3                         3   \n",
       "1          ...                           3                         1   \n",
       "2          ...                           4                         4   \n",
       "3          ...                           3                         3   \n",
       "4          ...                           3                         1   \n",
       "\n",
       "   StockOptionLevel  TotalWorkingYears  TrainingTimesLastYear  \\\n",
       "0                 1                  7                      2   \n",
       "1                 1                 33                      2   \n",
       "2                 0                  9                      3   \n",
       "3                 1                 21                      3   \n",
       "4                 2                  1                      2   \n",
       "\n",
       "   WorkLifeBalance  YearsAtCompany  YearsInCurrentRole  \\\n",
       "0                4               7                   5   \n",
       "1                1               5                   4   \n",
       "2                3               9                   7   \n",
       "3                3              21                   6   \n",
       "4                3               1                   0   \n",
       "\n",
       "   YearsSinceLastPromotion  YearsWithCurrManager  \n",
       "0                        0                     7  \n",
       "1                        1                     4  \n",
       "2                        0                     6  \n",
       "3                       11                     8  \n",
       "4                        0                     0  \n",
       "\n",
       "[5 rows x 29 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pf.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1100 entries, 0 to 1099\n",
      "Data columns (total 29 columns):\n",
      "Age                         1100 non-null int64\n",
      "Attrition                   1100 non-null int64\n",
      "BusinessTravel              1100 non-null int64\n",
      "Department                  1100 non-null int64\n",
      "DistanceFromHome            1100 non-null int64\n",
      "Education                   1100 non-null int64\n",
      "EducationField              1100 non-null int64\n",
      "EmployeeNumber              1100 non-null int64\n",
      "EnvironmentSatisfaction     1100 non-null int64\n",
      "Gender                      1100 non-null int64\n",
      "JobInvolvement              1100 non-null int64\n",
      "JobLevel                    1100 non-null int64\n",
      "JobRole                     1100 non-null int64\n",
      "JobSatisfaction             1100 non-null int64\n",
      "MaritalStatus               1100 non-null int64\n",
      "MonthlyIncome               1100 non-null int64\n",
      "NumCompaniesWorked          1100 non-null int64\n",
      "OverTime                    1100 non-null int64\n",
      "PercentSalaryHike           1100 non-null int64\n",
      "PerformanceRating           1100 non-null int64\n",
      "RelationshipSatisfaction    1100 non-null int64\n",
      "StockOptionLevel            1100 non-null int64\n",
      "TotalWorkingYears           1100 non-null int64\n",
      "TrainingTimesLastYear       1100 non-null int64\n",
      "WorkLifeBalance             1100 non-null int64\n",
      "YearsAtCompany              1100 non-null int64\n",
      "YearsInCurrentRole          1100 non-null int64\n",
      "YearsSinceLastPromotion     1100 non-null int64\n",
      "YearsWithCurrManager        1100 non-null int64\n",
      "dtypes: int64(29)\n",
      "memory usage: 249.3 KB\n"
     ]
    }
   ],
   "source": [
    "pf.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = torch.from_numpy(np.array(pf['Attrition']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.LongTensor'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1100])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = y_train.reshape(1100,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = y_train.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "pf = pf.drop(['Attrition'],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "pfnp = np.array(pf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dtype('int64')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pfnp.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.61666667, 1.        , 0.5       , ..., 0.27777778, 0.        ,\n",
       "        0.41176471],\n",
       "       [0.9       , 0.5       , 0.5       , ..., 0.22222222, 0.06666667,\n",
       "        0.23529412],\n",
       "       [0.56666667, 0.5       , 0.5       , ..., 0.38888889, 0.        ,\n",
       "        0.35294118],\n",
       "       ...,\n",
       "       [0.61666667, 1.        , 1.        , ..., 0.        , 0.        ,\n",
       "        0.        ],\n",
       "       [0.36666667, 1.        , 0.5       , ..., 0.        , 0.        ,\n",
       "        0.        ],\n",
       "       [0.43333333, 0.5       , 0.5       , ..., 0.11111111, 0.06666667,\n",
       "        0.11764706]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pfnp = pfnp/pfnp.max(axis=0)\n",
    "pfnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = torch.from_numpy(pfnp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1100, 28])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.FloatTensor'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train = x_train.float()\n",
    "x_train.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.6167, 1.0000, 0.5000,  ..., 0.2778, 0.0000, 0.4118],\n",
       "        [0.9000, 0.5000, 0.5000,  ..., 0.2222, 0.0667, 0.2353],\n",
       "        [0.5667, 0.5000, 0.5000,  ..., 0.3889, 0.0000, 0.3529],\n",
       "        ...,\n",
       "        [0.6167, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.3667, 1.0000, 0.5000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.4333, 0.5000, 0.5000,  ..., 0.1111, 0.0667, 0.1176]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (l1): Linear(in_features=28, out_features=128, bias=True)\n",
      "  (l2): Linear(in_features=128, out_features=32, bias=True)\n",
      "  (l3): Linear(in_features=32, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.l1 = nn.Linear(28,128)\n",
    "        self.l2 = nn.Linear(128,32)\n",
    "        self.l3 = nn.Linear(32,1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.l1(x))\n",
    "        x = F.relu(self.l2(x))\n",
    "        x = torch.sigmoid(self.l3(x))\n",
    "        return x\n",
    "    \n",
    "net = Net()\n",
    "\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)\n",
    "loss_func = torch.nn.MSELoss() \n",
    "\n",
    "for i in range(1000):\n",
    "    out = net(x_train)\n",
    "    loss = loss_func(out,y_train)\n",
    "    optimizer.zero_grad()   # 清空上一步的残余更新参数值\n",
    "    loss.backward()         # 误差反向传播, 计算参数更新值\n",
    "    optimizer.step()        # 将参数更新值施加到 net 的 parameters 上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([350, 28])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pf_pre = pd.read_csv('data/FEtest.csv')\n",
    "pf_pre = pf_pre.drop(['Over18','StandardHours'],axis=1)\n",
    "pfnp_pre = np.array(pf_pre)\n",
    "pfnp_pre = pfnp_pre/pfnp_pre.max(axis=0)\n",
    "x_pre = torch.from_numpy(pfnp_pre)\n",
    "x_pre.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_pre = x_pre.float()\n",
    "pre = net(x_pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre = pre.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "350"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [1.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.],\n",
       "       [0.]], dtype=float32)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for x in range(len(pre)):\n",
    "    if pre[x]>0.5:\n",
    "        pre[x]=1\n",
    "    else:\n",
    "        pre[x]=0\n",
    "pre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = pd.DataFrame(pre,dtype='int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>320</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>321</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>323</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>324</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>325</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>326</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>327</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>328</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>329</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>330</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>331</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>332</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>333</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>334</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>335</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>336</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>337</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>338</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>339</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>340</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>341</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>342</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>343</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>344</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>345</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>346</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>347</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>348</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>349</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>350 rows × 1 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     0\n",
       "0    0\n",
       "1    0\n",
       "2    0\n",
       "3    0\n",
       "4    0\n",
       "5    0\n",
       "6    0\n",
       "7    0\n",
       "8    0\n",
       "9    0\n",
       "10   0\n",
       "11   1\n",
       "12   1\n",
       "13   0\n",
       "14   0\n",
       "15   0\n",
       "16   0\n",
       "17   0\n",
       "18   0\n",
       "19   0\n",
       "20   0\n",
       "21   0\n",
       "22   0\n",
       "23   0\n",
       "24   1\n",
       "25   0\n",
       "26   0\n",
       "27   0\n",
       "28   0\n",
       "29   0\n",
       "..  ..\n",
       "320  0\n",
       "321  0\n",
       "322  0\n",
       "323  0\n",
       "324  0\n",
       "325  0\n",
       "326  0\n",
       "327  0\n",
       "328  0\n",
       "329  0\n",
       "330  0\n",
       "331  0\n",
       "332  0\n",
       "333  0\n",
       "334  0\n",
       "335  0\n",
       "336  0\n",
       "337  0\n",
       "338  0\n",
       "339  1\n",
       "340  0\n",
       "341  0\n",
       "342  0\n",
       "343  0\n",
       "344  0\n",
       "345  0\n",
       "346  0\n",
       "347  0\n",
       "348  0\n",
       "349  0\n",
       "\n",
       "[350 rows x 1 columns]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Request up to 1100 rows of `out`; the stored output reports 350 rows,\n",
    "# so this asks for the whole frame (the rendered table is elided in the middle).\n",
    "out.head(n=1100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relabel the single column of `out` as 'result' (assigned in place on the frame).\n",
    "out.columns = ['result']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>result</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   result\n",
       "0       0\n",
       "1       0\n",
       "2       0\n",
       "3       0\n",
       "4       0"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of `out`.\n",
    "out.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write `out` to result2.csv in the working directory, omitting the row index.\n",
    "out.to_csv('result2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Attrition                   1.000000\n",
       "OverTime                    0.267080\n",
       "DistanceFromHome            0.088563\n",
       "JobRole                     0.086899\n",
       "EducationField              0.055239\n",
       "Department                  0.053364\n",
       "PerformanceRating           0.046762\n",
       "PercentSalaryHike           0.026604\n",
       "NumCompaniesWorked          0.025889\n",
       "Gender                      0.016750\n",
       "BusinessTravel              0.015483\n",
       "TrainingTimesLastYear      -0.043395\n",
       "EmployeeNumber             -0.045168\n",
       "Education                  -0.046494\n",
       "WorkLifeBalance            -0.048794\n",
       "RelationshipSatisfaction   -0.051749\n",
       "YearsSinceLastPromotion    -0.071760\n",
       "EnvironmentSatisfaction    -0.097003\n",
       "JobInvolvement             -0.122722\n",
       "JobSatisfaction            -0.125568\n",
       "StockOptionLevel           -0.138498\n",
       "YearsAtCompany             -0.143697\n",
       "MaritalStatus              -0.153045\n",
       "MonthlyIncome              -0.155521\n",
       "YearsWithCurrManager       -0.158558\n",
       "YearsInCurrentRole         -0.163059\n",
       "JobLevel                   -0.168775\n",
       "Age                        -0.175393\n",
       "TotalWorkingYears          -0.187922\n",
       "Over18                           NaN\n",
       "StandardHours                    NaN\n",
       "Name: Attrition, dtype: float64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pairwise correlation matrix of `pf`, then the 'Attrition' column of that\n",
    "# matrix sorted from the largest coefficient to the smallest.\n",
    "# NOTE(review): the stored output still lists Over18 and StandardHours, which an\n",
    "# earlier cell drops from `pf` -- this cell appears to have been run before that\n",
    "# drop; re-run after Restart & Run All to refresh the output.\n",
    "corrDf = pf.corr()\n",
    "corrDf['Attrition'].sort_values(ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
